#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
搜索算法的实现
author: yooongchun@foxmail.com
"""
import time
import copy

import chessboard
import evaluate
from util import Chess, Point, Score


INF = float("inf")

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
class MinMaxSearcher:
    """Game-tree search (minimax with alpha-beta pruning).

    The search is always run from BLACK's point of view: when WHITE is the side
    to move on the given board, a reversed copy of the board is searched
    instead (see ``__init__``), so even depths (MAX nodes) belong to the side
    to move at the root.
    """

    # Transposition-table bound flags: how a cached score relates to the true
    # minimax value of the cached position.
    _EXACT = 0  # cached score is the exact value
    _LOWER = 1  # true value >= cached score (the search failed high)
    _UPPER = 2  # true value <= cached score (the search failed low)

    def __init__(self, board: chessboard.ChessBoard, neighbor: int = 2):
        """Initialize the searcher.

        Args:
            board: board to search from. It is deep-copied, so the caller's
                board is never modified by the search.
            neighbor: range around existing stones in which candidate moves
                are generated (forwarded to ``get_empty``).
        """
        self._size = board.size
        self._neighbor = neighbor
        self._chessboard = copy.deepcopy(board)
        # The side to move comes from board.next(). If WHITE is to move, the
        # copied board is reversed (presumably swapping colours) so that the
        # side to move can be searched as BLACK.
        self._first = True  # whether the side to move is the first player
        if board.next() == Chess.WHITE:
            self._chessboard.reverse()
            self._first = False

        # Transposition table:
        # (position key, max_depth, depth) -> (score, bscore, wscore, bound flag)
        self._cache = {}
        # Attack/defence weighting used by _cal_score.
        self._attack_factor = 1.2
        # Best root score seen so far in the current root search.
        self._prev_score = -INF
        # Search results
        self.best_move = None  # best move found for the side to move
        self.used_time = 0  # wall-clock seconds spent in the last search
        self.total_path = 0  # number of leaf positions evaluated
        self.cached_path = 0  # number of entries written to the cache
        self.cache_hit = 0  # number of cache hits

    @property
    def avg_time(self):
        """Average time per evaluated leaf position."""
        return self.used_time / max(1, self.total_path)

    @property
    def cache_hit_rate(self):
        """Cache hits divided by cache writes."""
        return self.cache_hit / max(1, self.cached_path)

    def _mark_best_move(self, depth: int, move: Point, score: float):
        """Remember ``move`` as the best root move if it beats the best root
        score seen so far in the current root search.

        Only depth 0 (the root) records moves; calls at other depths are no-ops.
        """
        if depth != 0:
            return
        # getattr keeps this safe on instances that bypassed __init__.
        if score > getattr(self, "_prev_score", -INF):
            self._prev_score = score
            self.best_move = move

    def _cal_score(self, depth: int):
        """Statically evaluate the current position.

        Args:
            depth: number of most recent history moves handed to the evaluator.

        Returns:
            ``(score, bscore, wscore)``: the weighted combined score followed by
            BLACK's and WHITE's raw evaluation scores. When the side to move
            moved first, BLACK's score is multiplied by ``_attack_factor``
            (attack-oriented); otherwise WHITE's score is (defence-oriented).
        """
        last_n = self._chessboard.get_history(depth)
        evaluater = evaluate.Evaluation(self._chessboard, last_n)
        bscore = evaluater.black_score()
        wscore = evaluater.white_score()
        if self._first:  # first player: favour attack
            return self._attack_factor * bscore - wscore, bscore, wscore
        else:  # second player: favour defence
            return bscore - self._attack_factor * wscore, bscore, wscore

    def _sort_candidates(self, candidates: list, descending: bool = True):
        """Order candidate moves for alpha-beta by a one-ply static evaluation.

        Each move is played, scored as ``black_score - white_score`` and then
        taken back.

        Args:
            candidates: moves to order.
            descending: ``True`` puts the highest-scoring moves first (the best
                order for a MAX node); ``False`` puts the lowest first (the best
                order for a MIN node).

        Returns:
            A list of indices into ``candidates`` in search order.
        """
        scores = []
        for move in candidates:
            self._chessboard.go(move)
            e = evaluate.Evaluation(self._chessboard, [move])
            # Read the scores while the move is still on the board, in case the
            # evaluator computes them lazily from the live board.
            scores.append(e.black_score() - e.white_score())
            self._chessboard.back()
        return sorted(range(len(candidates)), key=scores.__getitem__, reverse=descending)

    @classmethod
    def _bound_flag(cls, score: float, alpha: float, beta: float, is_leaf: bool):
        """Classify a child's returned score against the window it was searched with."""
        if is_leaf or alpha < score < beta:
            return cls._EXACT
        if score <= alpha:
            return cls._UPPER
        return cls._LOWER

    @classmethod
    def _cache_usable(cls, entry: tuple, alpha: float, beta: float):
        """Whether a cached ``(score, bscore, wscore, flag)`` entry can replace
        a search of the same position with window ``(alpha, beta)``."""
        score, flag = entry[0], entry[3]
        if flag == cls._EXACT:
            return True
        if flag == cls._LOWER:
            return score >= beta
        return score <= alpha

    def _min_max_search(self, max_depth: int = 0, depth: int = 0, alpha: float = -INF,
                        beta: float = INF, path=None):
        """Recursive alpha-beta minimax search.

        Even depths are MAX nodes (side to move at the root), odd depths are
        MIN nodes.

        Args:
            max_depth: depth at which positions are statically evaluated.
            depth: current depth (0 at the root).
            alpha: lower end of the search window (value MAX can already secure).
            beta: upper end of the search window (value MIN can already secure).
            path: moves played from the root to this node. It is carried along
                for callers/debugging only and does not influence the result.

        Returns:
            ``(score, bscore, wscore)``: the node's value (a bound when the
            search failed low or was cut off) and the BLACK/WHITE evaluation
            scores of the line that produced it.
        """
        path = [] if path is None else path
        if depth == 0:
            # New root search: forget the best score of any earlier iteration
            # so a deeper search can replace the move chosen by a shallower one.
            self._prev_score = -INF
        # Leaf: depth limit reached or no empty point left.
        if depth >= max_depth or self._chessboard.is_full():
            self.total_path += 1
            score, bscore, wscore = self._cal_score(depth)
            return score, bscore, wscore

        candidates = self._chessboard.get_empty(neighbor=self._neighbor, shuffle=True)
        is_max_node = depth % 2 == 0  # even depth -> MAX node
        child_is_leaf = depth + 1 >= max_depth
        # Evaluation scores of the child that set this node's value; fall back
        # to the last explored child if no child improved the window.
        best_b = best_w = None
        last_b = last_w = None
        improved = False
        for idx in self._sort_candidates(candidates, descending=is_max_node):
            move = candidates[idx]
            current = self._chessboard.go(move)
            # Cached values are only valid for the same position at the same
            # depth of the same-depth search, and only if their bound type
            # fits the current window.
            key = (current, max_depth, depth + 1)
            entry = self._cache.get(key)
            if entry is not None and self._cache_usable(entry, alpha, beta):
                self.cache_hit += 1
                score, bscore, wscore = entry[0:3]
            else:
                score, bscore, wscore = self._min_max_search(
                    max_depth, depth + 1, alpha, beta, path + [move])
                flag = self._bound_flag(score, alpha, beta, child_is_leaf)
                self._cache[key] = (score, bscore, wscore, flag)
                self.cached_path += 1
            # Undo the move before updating the window.
            self._chessboard.back()
            last_b, last_w = bscore, wscore
            if is_max_node and score > alpha:  # MAX node
                alpha = score
                best_b, best_w, improved = bscore, wscore, True
                self._mark_best_move(depth, move, score)
                # alpha-beta cutoff
                if score >= beta:
                    break
            elif not is_max_node and score < beta:  # MIN node
                beta = score
                best_b, best_w, improved = bscore, wscore, True
                self._mark_best_move(depth, move, score)
                # alpha-beta cutoff
                if score <= alpha:
                    break
        if not improved:
            best_b, best_w = last_b, last_w
        return (alpha if is_max_node else beta), best_b, best_w

    def _depth_recur(self, max_depth: int):
        """Iterative deepening: search depths 2, 4, ... up to ``max_depth``.

        When ``max_depth`` is below 2 a single one-ply search is run instead, so
        that ``best_move`` is always set.

        Returns:
            The root score of the last completed search (``-INF`` if none ran).
        """
        depths = list(range(2, max_depth + 1, 2)) or [1]
        depth_score = -INF
        for d in depths:
            depth_score, bscore, wscore = self._min_max_search(d)
            # Stop deepening once the chosen line reaches a decisive shape.
            if bscore is not None and bscore >= Score.LIVE4:
                return depth_score
            if wscore is not None and wscore >= Score.RUSH4:
                return depth_score
            if self._chessboard.steps <= 10:
                # Within the first 10 moves only the shallowest depth is searched.
                break
        return depth_score

    def search(self, max_depth: int = 2):
        """Search for the best move; the move is stored in ``self.best_move``.

        Args:
            max_depth: maximum iterative-deepening depth in plies.

        Returns:
            The root score, rounded to an int unless it is +/-inf.
            NOTE(review): on an empty board no search runs; the centre point is
            chosen and its coordinate (``size // 2``) is returned in place of a
            score, which is kept for backward compatibility.
        """
        # No need to search an empty board: take the centre.
        if self._chessboard.is_empty():
            center = self._chessboard.size // 2
            self.best_move = Point(center, center)
            return center
        start = time.time()
        score = self._depth_recur(max_depth)
        self.used_time = time.time() - start
        return score if score in (INF, -INF) else round(score)
